iT邦幫忙

2024 iThome 鐵人賽

DAY 19
0
AI/ ML & Data

【AI筆記】30天從論文入門到 Pytorch 實戰系列 第 19

【AI筆記】30天從論文入門到 Pytorch 實戰:模型訓練後的保存與載入技巧 Day 18

  • 分享至 

  • xImage
  •  

T2I-Adapter

Save Model

通常我們不會每個 model 都進行 save,而是設定幾次存一次 Model ,在 PyTorch 中,你可以使用 torch.save 函數來保存模型的狀態字典(state_dict)。這是保存模型的推薦方法。以下是保存模型的範例:

設定 Iteration

每個epoch都會抓N個batch去訓練Model,其中這邊的一個iteration就代表抓一個batch去訓練Model
https://github.com/fan84sunny/T2I-Adapter/blob/7056c7fb080afc55ce246b2e8b5d0488977fc72f/train_seg.py#L318
Save Model
所以在設定幾個iteration去存model一次? config['training']['save_freq']在哪裡?
https://github.com/fan84sunny/T2I-Adapter/blob/main/train_seg.py#L206

config = OmegaConf.load(f"{opt.config}")
# 這邊是自己下指令時才會自動去抓你下的config是用哪個?

位址
config['training']['save_freq'] = 1e4
https://ithelp.ithome.com.tw/upload/images/20240819/20168385y4CcPwJmPW.jpg
OmegaConf他會去抓yaml的階層,有階層關係,所以才會變成['training']['save_freq']

# 假設你有一個模型叫做 model
model_ad = ...  # 你的模型定義

# 提取出原始模型
model_ad_bare = get_bare_model(model_ad)

https://github.com/fan84sunny/T2I-Adapter/blob/main/train_seg.py#L324
這段程式碼的目的是從包裝的模型中提取出原始模型。當你使用 DataParallelDistributedDataParallel 來進行多 GPU 訓練時,模型會被包裝在這些類中。這樣做的原因是為了在多個 GPU 上分配和同步模型的參數。

然而,有時候你可能需要訪問原始的模型(即未包裝的模型),例如在保存模型或進行某些操作時。這段程式碼的 get_bare_model 函數就是為了這個目的而設計的。

以下是這段程式碼的詳細解釋:
目的是在保存模型的狀態字典(state_dict)時,移除多餘的 .module 前綴。當你使用 DataParallel 或 DistributedDataParallel 進行多 GPU 訓練時,模型的參數名稱會自動加上 .module 前綴。這段程式碼會移除這個前綴,然後將參數保存到一個新的字典中,並最終保存到指定的路徑。

檢查模型是否被包裝:

  • 如果是平行化運行的話,Model會多出.module,所以要改掉避免loading只能在多GPU環境上面運行。
  • 檢查模型是否被 DataParallelDistributedDataParallel 包裝。如果是,則提取出原始模型(即 net.module)。
if isinstance(net, (DataParallel, DistributedDataParallel)):
    net = net.module

這樣做的好處是,你可以在不使用多 GPU 訓練的情況下,方便地載入和使用這些模型參數。

# 保存模型的狀態字典
torch.save(model.state_dict(), 'model.pth')

# 因為原本訓練在 GPU 上面,要改成cpu 如果環境中只有CPU的狀況下才不會出問題
save_dict[key] = param.cpu()
torch.save(save_dict, save_path)

Load Model

https://github.com/fan84sunny/T2I-Adapter/blob/main/train_seg.py#L267

log

我很少使用到這個

這段程式碼的目的是在訓練過程中自動恢復模型的狀態,包括模型參數、優化器狀態和訓練進度。讓我們逐步解釋這段程式碼:

恢復訓練狀態

load_resume_state 函數

如果 opt.auto_resume 為真,則檢查 experiments 目錄下是否存在訓練狀態文件。如果存在,則找到最新的狀態文件並設置 resume_state_path

if opt.auto_resume:
    state_path = osp.join('experiments', opt.name, 'training_states')
    if osp.isdir(state_path):
        states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
        if len(states) != 0:
            states = [float(v.split('.state')[0]) for v in states]
            resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
            opt.resume_state_path = resume_state_path

檢查是否自動恢復

  • state_path = osp.join('experiments', opt.name, 'training_states') 要確定 這個資料夾底下 有.state的文件。
  • 使用 scandir 函數列出目錄中所有以 .state 為後綴的文件,並將結果存儲在 states 列表中。
  • states = [float(v.split('.state')[0]) for v in states]: 將每個狀態文件名中的數字部分提取出來,並轉換為浮點數。例如,'123.state' 會被轉換為 123.0。

    因為這個數字是iteration,要給訓練提取 iteration 到哪裡了

if opt.auto_resume:
    state_path = osp.join('experiments', opt.name, 'training_states')
    if osp.isdir(state_path):
        states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
        if len(states) != 0:
            states = [float(v.split('.state')[0]) for v in states]
            resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
            opt.resume_state_path = resume_state_path

簡單的Load Model

直接載入權重就好,不要管其他設定
因為基本上應該不會動...


# 創建一個新的模型實例
adapter_model = Adapter()

# 載入保存的狀態字典
adapter_model.load_state_dict(torch.load('adapter_model.pth'))

結論

在這篇文章中,我們探討了多個與模型訓練和保存相關的主題,並詳細解釋了如何在 PyTorch 中保存和載入模型。我們首先介紹了如何使用 torch.save 和 torch.load 函數來保存和載入模型的狀態字典,並提供了完整的範例代碼。接著,我們討論了如何從包裝的模型中提取出原始模型,這在使用 DataParallel 或 DistributedDataParallel 進行多 GPU 訓練時尤為重要。

我們還深入分析了如何在訓練過程中自動恢復模型的狀態,包括模型參數、優化器狀態和訓練進度。通過檢查目錄中的狀態文件並找到最新的狀態文件,我們可以方便地從上次中斷的地方繼續訓練,這對於長時間訓練的模型特別有用。
這些技巧和方法可以幫助我們更有效地管理和恢復模型的訓練狀態,從而提高訓練效率和模型性能。


上一篇
【AI筆記】30天從論文入門到 Pytorch 實戰:準備與整理Dataset Day 17
下一篇
【AI筆記】30天從論文入門到 Pytorch 實戰:全面評估你的模型表現 Day 19
系列文
【AI筆記】30天從論文入門到 Pytorch 實戰26
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言